{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5f2e91b4-b880-46a9-94be-8a85dd99a139",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "from openai import OpenAI\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from transformers import pipeline\n",
    "import pycountry\n",
    "from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import pycountry\n",
    "import logging\n",
    "import country_converter as coco\n",
    "\n",
    "# Register tqdm's progress_* methods (e.g. progress_apply) on pandas objects\n",
    "tqdm.pandas()  # Initialize tqdm for pandas\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9eaf5339-fa8f-4657-942b-66ba2d2cdf02",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Reuse `api_key` when an earlier cell defined it; otherwise pass None, which\n",
    "# lets the OpenAI client fall back to the OPENAI_API_KEY environment variable\n",
    "# instead of this cell failing with a NameError on a fresh kernel.\n",
    "client = OpenAI(api_key=globals().get(\"api_key\"))\n",
    "\n",
    "\n",
    "def _parse_country_list(content):\n",
    "    \"\"\"Extract the Python list embedded in a model reply.\n",
    "\n",
    "    Everything outside the outermost square brackets (code fences, a language\n",
    "    tag, surrounding whitespace) is ignored, so a fenced reply that ends with\n",
    "    a newline still parses.\n",
    "\n",
    "    Returns a list of strings. A reply that is empty, contains no list, or\n",
    "    cannot be parsed is logged as a warning and yields an empty list, so one\n",
    "    bad reply does not abort a long batch run.\n",
    "    \"\"\"\n",
    "    text = content or \"\"\n",
    "    start, end = text.find(\"[\"), text.rfind(\"]\")\n",
    "    if start == -1 or end < start:\n",
    "        logging.warning(\"No list found in model reply: %r\", content)\n",
    "        return []\n",
    "    try:\n",
    "        parsed = ast.literal_eval(text[start:end + 1])\n",
    "    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):\n",
    "        logging.warning(\"Could not parse model reply: %r\", content)\n",
    "        return []\n",
    "    if not isinstance(parsed, list):\n",
    "        logging.warning(\"Model reply is not a list: %r\", content)\n",
    "        return []\n",
    "    return [str(item) for item in parsed]\n",
    "\n",
    "\n",
    "def ask_gpt(abstract):\n",
    "    \"\"\"Ask gpt-4o which countries are mentioned in a text.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    abstract : str\n",
    "        Text to analyse; it is interpolated into the prompt as-is.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of str\n",
    "        The country codes parsed from the reply (the prompt requests ISO3),\n",
    "        or an empty list when the reply cannot be parsed.\n",
    "    \"\"\"\n",
    "    # The prompt text is kept exactly as it was so results stay comparable.\n",
    "    prompt = f\"% As a Name entity recognition tool, you are responsible for analyzing academic texts from various sources. Your task is to find all the country names mentioned in the abstract\\\\\\\\\\n% Constraint: Answer with only the python list of countries in iso3 format \\nmentioned in the text that is most accurate and nothing else.\\\\\\\\ \\\\\\\\\\n% Abstract:{abstract} \"\n",
    "    response = client.chat.completions.create(\n",
    "        model=\"gpt-4o\",\n",
    "        messages=[\n",
    "            {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": prompt}]},\n",
    "        ],\n",
    "        temperature=0.3,\n",
    "        max_tokens=2048,\n",
    "        top_p=0.5,\n",
    "        frequency_penalty=0,\n",
    "        presence_penalty=0,\n",
    "        response_format={\"type\": \"text\"},\n",
    "    )\n",
    "    return _parse_country_list(response.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b2030eeb-b0ea-4cdb-aa3d-ba05a57eff60",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _parse_annotation(value):\n",
    "    \"\"\"Turn a stringified annotation from the CSV back into a Python object.\n",
    "\n",
    "    ast.literal_eval only accepts literals, so a malformed cell cannot run\n",
    "    arbitrary code the way eval() could. Non-string values pass through.\n",
    "    \"\"\"\n",
    "    return ast.literal_eval(value) if isinstance(value, str) else value\n",
    "\n",
    "\n",
    "df = pd.read_csv(r'C:\\Users\\Yasaman\\Arab_spring_scholarly_attention\\Validating\\final_annotated.csv')\n",
    "# A missing Title or Abstract would make the whole concatenation NaN and send\n",
    "# the literal text \"nan\" to the model, so missing parts are treated as empty.\n",
    "df['Text'] = df['Title'].fillna('') + ' ' + df['Abstract'].fillna('')\n",
    "for annotation_column in ('union_annotation', 'intersection_annotation'):\n",
    "    df[annotation_column] = df[annotation_column].apply(_parse_annotation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f65a5811-96d1-47c7-bf3a-51fe81fc93a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [11:06<00:00,  1.50it/s]\n"
     ]
    }
   ],
   "source": [
    "# One model request per row; progress_apply renders a tqdm progress bar.\n",
    "predicted_locations = df[\"Text\"].progress_apply(ask_gpt)\n",
    "df[\"locations\"] = predicted_locations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7925d3e1-5de9-43d3-966b-1181353bdd63",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _lowercase_codes(codes):\n",
    "    \"\"\"Return the given country codes converted to lower case.\"\"\"\n",
    "    return [code.lower() for code in codes]\n",
    "\n",
    "\n",
    "df[\"locations\"] = df[\"locations\"].apply(_lowercase_codes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "87c2865a-dca6-4df5-9570-2b189fec501a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "overall accuracy (union) 0.812\n",
      "overall accuracy (intersection) 0.78\n"
     ]
    }
   ],
   "source": [
    "def _exact_match_rate(predicted, annotated):\n",
    "    \"\"\"Share of rows whose predicted country set equals the annotated set.\n",
    "\n",
    "    Rows are compared as sets, matching the per-group accuracy: comparing the\n",
    "    raw lists would count a correct answer in a different order as a miss.\n",
    "    The denominator is the actual number of rows; an empty input gives NaN.\n",
    "    \"\"\"\n",
    "    hits = [set(pred) == set(gold) for pred, gold in zip(predicted, annotated)]\n",
    "    return sum(hits) / len(hits) if hits else float('nan')\n",
    "\n",
    "\n",
    "print('overall accuracy (union)', _exact_match_rate(df['locations'], df['union_annotation']))\n",
    "print('overall accuracy (intersection)', _exact_match_rate(df['locations'], df['intersection_annotation']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "285c4414-1b47-4ee1-bdc3-46cc6dd22f49",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Yasaman\\AppData\\Local\\Temp\\ipykernel_33152\\2343847047.py:1: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
      "  sample_accuracies = df.groupby(\"SampleGroup\").apply(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SampleGroup</th>\n",
       "      <th>accuracy_union</th>\n",
       "      <th>accuracy_intersection</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>with_mention</td>\n",
       "      <td>0.803571</td>\n",
       "      <td>0.728571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>with_mention_arab</td>\n",
       "      <td>0.840000</td>\n",
       "      <td>0.790000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>field_20</td>\n",
       "      <td>0.930769</td>\n",
       "      <td>0.901923</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         SampleGroup  accuracy_union  accuracy_intersection\n",
       "0       with_mention        0.803571               0.728571\n",
       "1  with_mention_arab        0.840000               0.790000\n",
       "2           field_20        0.930769               0.901923"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score every row once (order-insensitive set equality), then average per\n",
    "# SampleGroup. Scoring before grouping avoids calling DataFrameGroupBy.apply\n",
    "# on the grouping column, which pandas has deprecated.\n",
    "row_matches = pd.DataFrame({\n",
    "    \"SampleGroup\": df[\"SampleGroup\"],\n",
    "    \"accuracy_union\": [\n",
    "        set(pred) == set(gold)\n",
    "        for pred, gold in zip(df[\"locations\"], df[\"union_annotation\"])\n",
    "    ],\n",
    "    \"accuracy_intersection\": [\n",
    "        set(pred) == set(gold)\n",
    "        for pred, gold in zip(df[\"locations\"], df[\"intersection_annotation\"])\n",
    "    ],\n",
    "})\n",
    "\n",
    "# The mean of a boolean column is the fraction of exact matches in the group.\n",
    "sample_accuracies = (\n",
    "    row_matches.groupby(\"SampleGroup\", as_index=False)\n",
    "    .mean()\n",
    "    .sort_values(by='accuracy_union')\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "sample_accuracies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "304f40c8-2539-48e2-a4db-d66a8ddeb014",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Yasaman\\AppData\\Local\\Temp\\ipykernel_33152\\388812476.py:1: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
      "  sample_accuracies = df.groupby(\"SampleGroup\").apply(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SampleGroup</th>\n",
       "      <th>jaccard_union</th>\n",
       "      <th>jaccard_intersection</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>with_mention</td>\n",
       "      <td>0.845103</td>\n",
       "      <td>0.778469</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>with_mention_arab</td>\n",
       "      <td>0.913708</td>\n",
       "      <td>0.855610</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>field_20</td>\n",
       "      <td>0.934982</td>\n",
       "      <td>0.907418</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         SampleGroup  jaccard_union  jaccard_intersection\n",
       "0       with_mention       0.845103              0.778469\n",
       "1  with_mention_arab       0.913708              0.855610\n",
       "2           field_20       0.934982              0.907418"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def _jaccard(predicted, annotated):\n",
    "    \"\"\"Jaccard similarity |A & B| / |A | B| of two country collections.\n",
    "\n",
    "    Two empty collections count as a perfect match (1.0), which also avoids\n",
    "    dividing by zero.\n",
    "    \"\"\"\n",
    "    predicted_set, annotated_set = set(predicted), set(annotated)\n",
    "    combined = predicted_set | annotated_set\n",
    "    if not combined:\n",
    "        return 1.0\n",
    "    return len(predicted_set & annotated_set) / len(combined)\n",
    "\n",
    "\n",
    "# Score every row once, then average per SampleGroup. Scoring before grouping\n",
    "# avoids calling DataFrameGroupBy.apply on the grouping column, which pandas\n",
    "# has deprecated.\n",
    "row_scores = pd.DataFrame({\n",
    "    \"SampleGroup\": df[\"SampleGroup\"],\n",
    "    \"jaccard_union\": [\n",
    "        _jaccard(pred, gold)\n",
    "        for pred, gold in zip(df[\"locations\"], df[\"union_annotation\"])\n",
    "    ],\n",
    "    \"jaccard_intersection\": [\n",
    "        _jaccard(pred, gold)\n",
    "        for pred, gold in zip(df[\"locations\"], df[\"intersection_annotation\"])\n",
    "    ],\n",
    "})\n",
    "\n",
    "sample_accuracies = (\n",
    "    row_scores.groupby(\"SampleGroup\", as_index=False)\n",
    "    .mean()\n",
    "    .sort_values(by='jaccard_union')\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "sample_accuracies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "dbfa5643-6067-4946-9a02-487f97a91b7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the frame, including the model predictions, without the row index.\n",
    "output_path = r'C:\\Users\\Yasaman\\Arab_spring_scholarly_attention\\Validating\\gpt.csv'\n",
    "df.to_csv(output_path, index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
